/*
*  This file is part of ygg-brute
*  Copyright (c) 2020 ygg-brute authors
*  See LICENSE for licensing information
*/

/*
 * compute_uv_kernel_0: first stage of the two-kernel "uv" computation.
 *
 * One work-item per midstate element (index gid, gsize elements in total).
 * With q = the FE_SIZE-word element of addends_xy selected by *seq:
 *   c = xy * q          (xy = midstate_xy element gid)
 *   e = two - c
 *   f = two + c         (`two` is a field constant defined elsewhere;
 *                        presumably the value 2 — confirm)
 *   z = e * f
 *
 * Outputs:
 *   zs  - z for element gid, written by fe_store_global. That helper takes
 *         no index; presumably it stores at gid with stride gsize — confirm.
 *   efs - 2 * gsize interleaved entries with column stride efs_n = 2 * gsize:
 *         f at index 2 * gid, e at index 2 * gid + 1.
 *         NOTE(review): compute_uv_kernel_1 reads zs/efs and appears to run
 *         with twice this kernel's global size, one work-item per efs entry.
 *
 * Inputs:
 *   midstate_xy - read with col_load using stride gsize, i.e. presumably the
 *                 FE_SIZE words of element gid sit gsize words apart.
 *   addends_xy  - table of elements stored contiguously, FE_SIZE words each.
 *   seq         - single index into addends_xy; read only by local work-item 0.
 */
KERNEL_SPEC(compute_uv_kernel_0, BLOCK_SIZE)
(
    GLOBAL uint32_t* zs,
    GLOBAL uint32_t* efs,
    GLOBAL const uint32_t* __restrict__ midstate_xy,
    GLOBAL const uint32_t* __restrict__ addends_xy,
    GLOBAL const uint64_t* __restrict__ seq
) {
    const size_t gid = global_id();
    const size_t lid = local_id();
    const size_t gsize = global_size();
    /* efs holds two entries (f and e) per element. */
    const size_t efs_n = gsize * 2;

    /* Addend shared by every work-item of the block. */
    LOCAL FieldElement q_xy;

    FieldElement c, e, f, xy, z;

    /* A single work-item copies the selected addend; the barrier below makes
     * it visible to the rest of the block before anyone reads q_xy. */
    if(lid == 0) {
        fe_copy(q_xy, addends_xy + FE_SIZE * (*seq));
    }

    SYNCTHREADS();

    col_load(xy, midstate_xy, FE_SIZE, gsize, gid);

    fe_mul(c, xy, q_xy);

    fe_sub(e, two, c);
    fe_add(f, two, c);
    fe_mul(z, e, f);

    fe_store_global(zs, z);

    /* Interleave: e goes to the odd slot and f to the even slot, so a consumer
     * indexing efs by a doubled work-item id picks f (even) or e (odd). */
    fe_store(efs, efs_n, e, 2 * gid + 1);
    fe_store(efs, efs_n, f, 2 * gid);
}

/*
 * compute_uv_kernel_1: second stage of the two-kernel "uv" computation,
 * consuming the zs/efs outputs of compute_uv_kernel_0.
 *
 * Work-items operate in adjacent pairs (lid, lid ^ 1) that share one element:
 * index gid / 2 into midstate_ypx, midstate_ymx and zs, with stride gsize / 2.
 * NOTE(review): this assumes the global size is twice compute_uv_kernel_0's,
 * and that gsize and the local size are even, so a pair never straddles a block.
 *
 * With qp = addends_ypx[*seq], qm = addends_ymx[*seq], and P / M the
 * midstate_ypx / midstate_ymx elements at gid / 2:
 *   even lid:  d = P * qp + M * qm
 *   odd lid:   d = M * qp + P * qm
 *   y = d * ef   (ef loaded from efs by fe_load_global; presumably entry gid,
 *                 which compute_uv_kernel_0 filled with f for even and e for
 *                 odd indices — confirm)
 *   u = z - y,  v = z + y,  z = zs element gid / 2
 * u and v are written to us and vs by fe_store_global.
 *
 * The midstate operands are staged in the block-shared buffer abms. Even
 * work-items stage P and odd work-items stage M, so each element is read from
 * global memory once per pair. After the barrier, each work-item reads its own
 * slot (lid) and its partner's slot (lid ^ 1).
 */
KERNEL_SPEC(compute_uv_kernel_1, BLOCK_SIZE)
(
    GLOBAL uint32_t* us,
    GLOBAL uint32_t* vs,
    GLOBAL const uint32_t* __restrict__ zs,
    GLOBAL const uint32_t* __restrict__ efs,
    GLOBAL const uint32_t* __restrict__ midstate_ypx,
    GLOBAL const uint32_t* __restrict__ midstate_ymx,
    GLOBAL const uint32_t* __restrict__ addends_ypx,
    GLOBAL const uint32_t* __restrict__ addends_ymx,
    GLOBAL const uint64_t* __restrict__ seq
) {
    const size_t gid = global_id();
    const size_t lid = local_id();
    const size_t gsize = global_size();
    const size_t lsize = local_size();

    /* Both work-items of a pair address the same element. */
    const size_t pair_gid = gid / 2;
    const size_t pair_gsize = gsize / 2;
    /* The other work-item of this pair: lid + 1 for even lid, lid - 1 for odd. */
    const size_t partner_lid = lid ^ 1;

    /* Addends shared by every work-item of the block. */
    LOCAL FieldElement q_ypx, q_ymx;

#ifdef __CUDACC__
    /* Dynamically sized: the launch must reserve at least lsize * FE_SIZE
     * uint32_t words (indexed below as abms[lsize * i + lid]). */
    extern LOCAL uint32_t abms[];
#else
    LOCAL uint32_t abms[BLOCK_SIZE * FE_SIZE];
#endif

    FieldElement a, b, d, ef, y, z;

    /* A single work-item copies the selected addends; the barrier below
     * publishes them together with the staged midstate operands. */
    if(lid == 0) {
        fe_copy(q_ypx, addends_ypx + FE_SIZE * (*seq));
        fe_copy(q_ymx, addends_ymx + FE_SIZE * (*seq));
    }

    /* Even work-items stage the ypx operand, odd work-items the ymx operand,
     * into column lid of abms (row stride lsize). */
    GLOBAL const uint32_t* stage_from = (lid % 2 == 0) ? midstate_ypx : midstate_ymx;
    for(size_t i = 0; i < FE_SIZE; ++i) {
        abms[lsize * i + lid] = stage_from[pair_gsize * i + pair_gid];
    }

    SYNCTHREADS();

    /* Own operand times qp, partner's operand times qm. */
    col_load(d, abms, FE_SIZE, lsize, lid);
    fe_mul(a, d, q_ypx);
    col_load(d, abms, FE_SIZE, lsize, partner_lid);
    fe_mul(b, d, q_ymx);
    fe_add(d, a, b);

    fe_load_global(ef, efs);

    fe_mul(y, d, ef);

    fe_load(z, zs, pair_gsize, pair_gid);

    FieldElement u, v;
    fe_sub(u, z, y);
    fe_add(v, z, y);

    fe_store_global(us, u);
    fe_store_global(vs, v);
}